/*
 * Copyright (c) 2021 Raspberry Pi (Trading) Ltd.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <array>
#include <algorithm>
#include <sstream>
#include <iomanip>
#include <iostream>
#include "output_format.h"
#include "pio_disassembler.h"

/**
 * Output format that writes a Python module for MicroPython's `rp2` PIO
 * assembler: each pioasm program becomes an `@rp2.asm_pio(...)` decorated
 * function whose body is the disassembled instruction stream.
 */
struct python_output : public output_format {
    // Constructing a factory registers a new python_output via output_format::add.
    struct factory {
        factory() {
            output_format::add(new python_output());
        }
    };

    python_output() : output_format("python") {}

    std::string get_description() override {
        return "Python file suitable for use with MicroPython";
    }

    /**
     * Write symbol definitions as Python assignments.
     *
     * Non-label symbols are written first as `<prefix><name> = <value>`, then
     * labels as `<prefix>offset_<name> = <value>`. Each non-empty group is
     * followed by a blank line.
     */
    void output_symbols(FILE *out, std::string prefix, const std::vector<compiled_source::symbol> &symbols) {
        int count = 0;
        for (const auto &s : symbols) {
            if (!s.is_label) {
                fprintf(out, "%s%s = %d\n", prefix.c_str(), s.name.c_str(), s.value);
                count++;
            }
        }
        if (count) {
            fprintf(out, "\n");
            count = 0;
        }
        for (const auto &s : symbols) {
            if (s.is_label) {
                fprintf(out, "%soffset_%s = %d\n", prefix.c_str(), s.name.c_str(), s.value);
                count++;
            }
        }
        if (count) {
            fprintf(out, "\n");
        }
    }

    /** Write `msg` as a three-line `#` comment banner followed by a blank line. */
    void header(FILE *out, std::string msg) {
        std::string dashes = std::string(msg.length(), '-');
        fprintf(out, "# %s #\n", dashes.c_str());
        fprintf(out, "# %s #\n", msg.c_str());
        fprintf(out, "# %s #\n", dashes.c_str());
        fprintf(out, "\n");
    }

    /**
     * Write the whole compiled source to `destination` as a Python module.
     *
     * @param destination    output target, passed to open_single_output()
     * @param output_options not used by this format
     * @param source         assembled programs and global symbols
     * @return 0 on success, 1 if the output could not be opened
     */
    int output(std::string destination, std::vector<std::string> output_options,
               const compiled_source &source) override {
        FILE *out = open_single_output(destination);
        if (!out) return 1;

        header(out, "This file is autogenerated by pioasm; do not edit!");

        fprintf(out, "import rp2\n");
        fprintf(out, "from machine import Pin");
        fprintf(out, "\n");

        output_symbols(out, "", source.global_symbols);

        for (const auto &program : source.programs) {
            header(out, program.name);

            std::string prefix = program.name + "_";

            output_symbols(out, prefix, program.symbols);

            // Decorator arguments come from the program's lang_opts entries
            // whose key starts with this format's name.
            int param_count = 0;
            auto write_opt = [&] (const std::string &opt_name, const std::string &value) {
                if (param_count++) {
                    fprintf(out, ", ");
                }
                fprintf(out, "%s=%s", opt_name.c_str(), value.c_str());
            };
            fprintf(out, "@rp2.asm_pio(");
            for(const auto &p : program.lang_opts) {
                if (p.first.size() >= name.size() && p.first.compare(0, name.size(), name) == 0) {
                    for (const auto &p2 : p.second) {
                        write_opt(p2.first, p2.second);
                    }
                }
            }
            fprintf(out, ")\n");
            fprintf(out, "def %s():\n", program.name.c_str());

            // Every JMP target gets a label; for now the label is simply the
            // numeric target offset.
            std::map<uint, std::string> jmp_labels;
            for (const auto &instruction : program.instructions) {
                const auto inst = static_cast<uint16_t>(instruction);
                if (!(inst >> 13u)) {
                    // major opcode 0b000: JMP
                    uint target = inst & 0x1fu;
                    jmp_labels.emplace(target, std::to_string(target));
                }
            }

            for (uint i = 0; i < (uint)program.instructions.size(); i++) {
                const auto &inst = program.instructions[i];
                if (i == (uint)program.wrap_target) {
                    fprintf(out, "    wrap_target()\n");
                }
                auto it = jmp_labels.find(i);
                if (it != jmp_labels.end()) {
                    fprintf(out, "    label(\"%s\")\n", it->second.c_str());
                }
                fprintf(out, "    %s # %d\n", disassemble(jmp_labels, inst, program.sideset_bits_including_opt.get(), program.sideset_opt).c_str(), i);
                if (i == (uint)program.wrap) {
                    fprintf(out, "    wrap()\n");
                }
            }
            fprintf(out, "\n");

            /*
            fprintf(out, "static inline pio_sm_config %sprogram_default_config(uint offset) {\n", prefix.c_str());
            fprintf(out, "    pio_sm_config c = pio_sm_default_config();\n");
            fprintf(out, "    sm_config_wrap(&c, offset + %swrap_target, offset + %swrap);\n", prefix.c_str(),
                    prefix.c_str());
            if (program.sideset_bits_including_opt.is_specified()) {
                fprintf(out, "    sm_config_sideset(&c, %d, %s, %s);\n", program.sideset_bits_including_opt.get(),
                        program.sideset_opt ? "true" : "false",
                        program.sideset_pindirs ? "true" : "false");
            }
            fprintf(out, "    return c;\n");
            fprintf(out, "}\n");
*/
            // todo maybe have some code blocks inside or outside here?
            // Copy through code blocks written for this output format only; the
            // separator line is emitted only for blocks that are actually copied.
            for (const auto &o : program.code_blocks) {
                if (o.first != name) continue;
                fprintf(out, "\n");
                for (const auto &contents : o.second) {
                    fprintf(out, "%s\n", contents.c_str());
                }
            }

            fprintf(out, "\n");
        }
        if (out != stdout) { fclose(out); }
        return 0;
    }

    /**
     * Disassemble one PIO instruction into a line of MicroPython rp2 PIO
     * assembler, e.g. `jmp(not_x, "3")`, followed by an optional `.side(n)`
     * and `[delay]` suffix. Columns are padded to fixed widths so the listing
     * lines up.
     *
     * @param jmp_labels  label names keyed by instruction offset; a JMP whose
     *                    target has no entry is rendered with label "?"
     * @param inst        the encoded instruction
     * @param sideset_bits_including_opt number of bits of the 5-bit
     *                    delay/side-set field used for side-set, counting the
     *                    optional-enable bit
     * @param sideset_opt whether side-set is optional (the field's MSB is then
     *                    an enable bit)
     * @return the formatted line; reserved/invalid encodings are emitted as
     *         `word(0xNNNN)`
     */
    static std::string disassemble(const std::map<uint, std::string>& jmp_labels, uint inst, uint sideset_bits_including_opt, bool sideset_opt) {
        std::stringstream ss;
        uint major = (inst >> 13u) & 0x7;
        uint arg1 = ((uint) inst >> 5u) & 0x7u;
        // NOTE(review): bit 16 of `inst` is folded into bit 5 of arg2; a plain
        // 16-bit PIO encoding never has that bit set — confirm what produces it.
        uint arg2 = (inst & 0x1fu) | ((inst & 0x10000) >> 11);
        std::string op_string;
        auto op = [&](const std::string &s) {
            op_string = s;
        };
        // Emits "<op>(<operands>)" left-aligned in a 24-column field.
        auto op_guts = [&](const std::string &s) {
            ss << std::left << std::setw(24) << (op_string + "(" + s + ")");
        };

        bool invalid = false;
        switch (major) {
            case 0b000: {
                static std::array<std::string, 8> conditions{"", "not_x", "x_dec", "not_y", "y_dec", "x_not_y", "pin",
                                                             "not_osre"};
                op("jmp");
                auto it = jmp_labels.find(arg2);
                std::string label = "?";
                if (it != jmp_labels.end()) {
                    label = it->second;
                }
                if (arg1)
                    op_guts(conditions[arg1] + ", \"" + label +"\"");
                else
                    op_guts("\"" + label + "\"");
                break;
            }
            case 0b001: {
                uint source = arg1 & 3u;
                std::string guts;
                switch (source) {
                    case 0b00:
                        guts = "gpio, " + std::to_string(arg2);
                        break;
                    case 0b01:
                        guts = "pin, " + std::to_string(arg2);
                        break;
                    case 0b10:
                        if (arg2 & 0x8u) {
                            invalid = true;
                        } else {
                            guts = "irq, ";
                            auto irq = std::to_string(arg2 & 7u);
                            if (arg2 & 0x10u) {
                                guts += "rel(" + irq + ")";
                            } else {
                                guts += irq;
                            }
                        }
                        break;
                    default:
                        invalid = true;
                        break;
                }
                if (!invalid) {
                    // bit 7 is the polarity to wait for
                    guts = ((arg1 & 4u) ? "1, " : "0, ") + guts;
                    op("wait");
                    op_guts(guts);
                }
                break;
            }
            case 0b010: {
                // IN source encodings 0b100 and 0b101 are reserved (STATUS is a
                // MOV source, not an IN source).
                static std::array<std::string, 8> sources { "pins", "x", "y", "null", "", "", "isr", "osr"};
                std::string source = sources[arg1];
                if (source.empty()) {
                    invalid = true;
                } else {
                    op("in_");
                    // a bit count of 0 encodes 32
                    op_guts(source + ", " + std::to_string(arg2 ? arg2 : 32));
                }
                break;
            }
            case 0b011: {
                static std::array<std::string, 8> dests { "pins", "x", "y", "null", "pindirs", "pc", "isr", "exec"};
                op("out");
                // a bit count of 0 encodes 32
                op_guts(dests[arg1] + ", " + std::to_string(arg2 ? arg2 : 32));
                break;
            }
            case 0b100: {
                if (arg2) {
                    invalid = true;
                } else {
                    std::string guts = "";
                    if (arg1 & 4u) {
                        op("pull");
                        if (arg1 & 2u) guts = "ifempty, ";
                    } else {
                        op("push");
                        if (arg1 & 2u) guts = "iffull, ";
                    }
                    guts += ((arg1 & 0x1u) ? "block" : "noblock");
                    op_guts(guts);
                }
                break;
            }
            case 0b101: {
                static std::array<std::string, 8> dests { "pins", "x", "y", "", "exec", "pc", "isr", "osr"};
                static std::array<std::string, 8> sources { "pins", "x", "y", "null", "", "status", "isr", "osr"};
                std::string dest = dests[arg1];
                std::string source = sources[arg2 & 7u];
                uint operation = arg2 >> 3u;
                if (source.empty() || dest.empty() || operation == 3) {
                    invalid = true;
                } else {
                    // `mov x, x` / `mov y, y` with no operation is rendered as nop().
                    // (dest == source with operation 0 already implies arg1 == arg2.)
                    if (dest == source && (arg1 == 1 || arg1 == 2) && operation == 0) {
                        op("nop");
                        op_guts("");
                    } else {
                        op("mov");
                        std::string guts = dest + ", ";
                        if (operation == 1) {
                            guts += "invert(";
                        } else if (operation == 2) {
                            guts += "reverse(";
                        }
                        guts += source;
                        if (operation == 1 || operation == 2) {
                            guts += ")";
                        }
                        op_guts(guts);
                    }
                }
                break;
            }
            case 0b110: {
                if ((arg1 & 0x4u) || (arg2 & 0x8u)) {
                    invalid = true;
                } else {
                    op("irq");
                    std::string guts;
                    if (arg1 & 0x2u) {
                        guts += "clear, ";
                    } else if (arg1 & 0x1u) {
                        guts += "block, ";
                    }
                    auto irq = std::to_string(arg2 & 7u);
                    if (arg2 & 0x10u) {
                        guts += "rel(" + irq + ")";
                    } else {
                        guts += irq;
                    }
                    op_guts(guts);
                }
                break;
            }
            case 0b111: {
                static std::array<std::string, 8> dests{"pins", "x", "y", "", "pindirs", "", "", ""};
                std::string dest = dests[arg1];
                if (dest.empty()) {
                    invalid = true;
                } else {
                    op("set");
                    op_guts(dest + ", " + std::to_string(arg2));
                }
                break;
            }
        }
        if (invalid) {
            op("word");
            // Write the prefix explicitly: std::showbase combined with a
            // right-aligned zero fill would pad in front of the "0x".
            std::stringstream guts;
            guts << "0x" << std::hex << std::setfill('0') << std::setw(4) << inst;
            op_guts(guts.str());
        }

        // The 5-bit delay/side-set field: side-set occupies the top
        // sideset_bits_including_opt bits (MSB = enable bit when optional).
        uint delay = ((uint) inst >> 8u) & 0x1f;
        std::string side;
        if (sideset_bits_including_opt && (!sideset_opt || (delay & 0x10u))) {
            side = ".side(" + std::to_string((delay & (sideset_opt ? 0xfu : 0x1fu)) >> (5u - sideset_bits_including_opt)) + ")";
        }
        ss << std::left << std::setw(9) << side;

        // Whatever lies below the side-set bits is the delay count.
        delay &= ((1u << (5 - sideset_bits_including_opt)) - 1u);
        ss << std::left << std::setw(4) << (delay ? ("[" + std::to_string(delay) + "]") : "");
        return ss.str();
    }
};

// File-scope instance: its constructor registers the "python" output format
// with output_format during static initialization of this translation unit.
static python_output::factory creator{};
